{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import os\n",
    "import re\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the notebook to GPU 4 unless the visible devices were already chosen in the\n",
    "# environment. This must run before the first CUDA call for it to take effect.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"4\")\n",
    "# Fall back to the CPU when no GPU is visible\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count()\n",
    "print(\"gpu num: \", n_gpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load pre-trained translation models from fairseq\n",
    "Russian and German are used as the intermediate (pivot) languages here; any other intermediate language or translation model can be substituted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the English<->Russian translation models from the fairseq hub\n",
    "_ru_hub_options = {'tokenizer': 'moses', 'bpe': 'fastbpe'}\n",
    "en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru.single_model', **_ru_hub_options)\n",
    "ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en.single_model', **_ru_hub_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the English<->German translation models from the fairseq hub\n",
    "_de_hub_options = {'tokenizer': 'moses', 'bpe': 'fastbpe'}\n",
    "en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', **_de_hub_options)\n",
    "de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', **_de_hub_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Move both models to `device` (GPU when available, otherwise CPU) instead of\n",
    "# calling .cuda(), which fails on a host without a visible GPU.\n",
    "en2ru.to(device)\n",
    "ru2en.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Move both models to `device` (GPU when available, otherwise CPU) instead of\n",
    "# calling .cuda(), which fails on a host without a visible GPU.\n",
    "en2de.to(device)\n",
    "de2en.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './'\n",
    "train_csv = os.path.join(path, 'train.csv')\n",
    "# The file is read without a header row, so columns are addressed by position\n",
    "train_df = pd.read_csv(train_csv, header=None)\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels come from column 0 shifted down by one; the text comes from column 2\n",
    "train_labels = (train_df[0] - 1).tolist()\n",
    "train_text = train_df[2].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_text[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max(train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split and get our unlabeled training data\n",
    "def train_val_split(labels, n_labeled_per_class, n_labels, seed = 0):\n",
    "    np.random.seed(seed)\n",
    "    labels = np.array(labels)\n",
    "    train_labeled_idxs = []\n",
    "    train_unlabeled_idxs = []\n",
    "    val_idxs = []\n",
    "\n",
    "    for i in range(n_labels):\n",
    "        idxs = np.where(labels == i)[0]\n",
    "        np.random.shuffle(idxs)\n",
    "        train_labeled_idxs.extend(idxs[:n_labeled_per_class])\n",
    "        train_unlabeled_idxs.extend(idxs[n_labeled_per_class : n_labeled_per_class + 10000])\n",
    "        val_idxs.extend(idxs[-3000:])\n",
    "    \n",
    "    np.random.shuffle(train_labeled_idxs)\n",
    "    np.random.shuffle(train_unlabeled_idxs)\n",
    "    np.random.shuffle(val_idxs)\n",
    "    return train_labeled_idxs, train_unlabeled_idxs, val_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(train_labels, 500, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_unlabeled_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs = train_unlabeled_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idxs[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Back translation process\n",
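    "Each text is translated into the intermediate language and back into English with sampling enabled. You can tune the sampling temperature (0.9 here) to control the diversity of the back-translated text."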
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# back translate using Russian as middle language\n",
    "def translate_ru(start, end, file_name):\n",
    "    trans_result = {}\n",
    "    for id in tqdm(range(start, end)):\n",
    "        trans_result[idxs[id]] = ru2en.translate(en2ru.translate(train_text[idxs[id]],  sampling = True, temperature = 0.9),  sampling = True, temperature = 0.9)\n",
    "        if id % 500 == 0:\n",
    "            with open(file_name, 'wb') as f:\n",
    "                pickle.dump(trans_result, f)\n",
    "    with open(file_name, 'wb') as f:\n",
    "        pickle.dump(trans_result, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# back translate using German as middle language\n",
    "def translate_de(start, end, file_name):\n",
    "    trans_result = {}\n",
    "    for id in tqdm(range(start, end)):\n",
    "        trans_result[idxs[id]] = de2en.translate(en2de.translate(train_text[idxs[id]],  sampling = True, temperature = 0.9),  sampling = True, temperature = 0.9)\n",
    "        if id % 500 == 0:\n",
    "            with open(file_name, 'wb') as f:\n",
    "                pickle.dump(trans_result, f)\n",
    "    with open(file_name, 'wb') as f:\n",
    "        pickle.dump(trans_result, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Translate every unlabeled sample; len(idxs) replaces the hard-coded 100000 so a\n",
    "# smaller split does not index past the end of `idxs`.\n",
    "translate_de(0, len(idxs), 'de_1.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
